{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524620c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.nbeatsx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15392f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12fa25a4",
   "metadata": {},
   "source": [
    "# NBEATSx"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1822c9e8",
   "metadata": {},
   "source": [
    "The Neural Basis Expansion Analysis (`NBEATS`) is an `MLP`-based deep neural architecture with backward and forward residual links. The network has two variants: (1) in its interpretable configuration, `NBEATS` sequentially projects the signal onto polynomial and harmonic bases to learn trend and seasonality components; (2) in its generic configuration, it replaces the polynomial and harmonic bases with an identity basis and a deeper network. The Neural Basis Expansion Analysis with Exogenous Variables (`NBEATSx`) incorporates projections onto exogenous temporal variables available at the time of the prediction.<br><br> This method achieved state-of-the-art performance on the M3, M4, and Tourism Competition datasets, improving accuracy by 3% over the `ESRNN` M4 competition winner. For Electricity Price Forecasting tasks, the `NBEATSx` model improved accuracy by 20% and 5% over `ESRNN` and `NBEATS`, respectively, and by 5% over task-specialized architectures.<br><br>**References**<br>-[Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). \"N-BEATS: Neural basis expansion analysis for interpretable time series forecasting\".](https://arxiv.org/abs/1905.10437)<br>-[Kin G. Olivares, Cristian Challu, Grzegorz Marcjasz, Rafał Weron, Artur Dubrawski (2021). \"Neural basis expansion analysis with exogenous variables: Forecasting electricity prices with NBEATSx\".](https://arxiv.org/abs/2104.05522)<br>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bddd17a6",
   "metadata": {},
   "source": [
    "![Figure 1. Neural Basis Expansion Analysis with Exogenous Variables.](imgs_models/nbeatsx.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25e1718a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "\n",
    "from fastcore.test import test_eq, test_fail\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.utils import generate_series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "262e6ab4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from typing import Tuple, Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.common._base_windows import BaseWindows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "621dd3a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7a9fae-2c29-47e2-874e-ca1f20bf7040",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class IdentityBasis(nn.Module):\n",
    "    def __init__(self, backcast_size: int, forecast_size: int, out_features: int = 1):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        self.forecast_size = forecast_size\n",
    "        self.backcast_size = backcast_size\n",
    "\n",
    "    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        backcast = theta[:, : self.backcast_size]\n",
    "        forecast = theta[:, self.backcast_size :]\n",
    "        forecast = forecast.reshape(len(forecast), -1, self.out_features)\n",
    "        return backcast, forecast\n",
    "\n",
    "\n",
    "class TrendBasis(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        degree_of_polynomial: int,\n",
    "        backcast_size: int,\n",
    "        forecast_size: int,\n",
    "        out_features: int = 1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        polynomial_size = degree_of_polynomial + 1\n",
    "        self.backcast_basis = nn.Parameter(\n",
    "            torch.tensor(\n",
    "                np.concatenate(\n",
    "                    [\n",
    "                        np.power(\n",
    "                            np.arange(backcast_size, dtype=float) / backcast_size, i\n",
    "                        )[None, :]\n",
    "                        for i in range(polynomial_size)\n",
    "                    ]\n",
    "                ),\n",
    "                dtype=torch.float32,\n",
    "            ),\n",
    "            requires_grad=False,\n",
    "        )\n",
    "        self.forecast_basis = nn.Parameter(\n",
    "            torch.tensor(\n",
    "                np.concatenate(\n",
    "                    [\n",
    "                        np.power(\n",
    "                            np.arange(forecast_size, dtype=float) / forecast_size, i\n",
    "                        )[None, :]\n",
    "                        for i in range(polynomial_size)\n",
    "                    ]\n",
    "                ),\n",
    "                dtype=torch.float32,\n",
    "            ),\n",
    "            requires_grad=False,\n",
    "        )\n",
    "\n",
    "    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        polynomial_size = self.forecast_basis.shape[0]  # [polynomial_size, L+H]\n",
    "        backcast_theta = theta[:, :polynomial_size]\n",
    "        forecast_theta = theta[:, polynomial_size:]\n",
    "        forecast_theta = forecast_theta.reshape(\n",
    "            len(forecast_theta), polynomial_size, -1\n",
    "        )\n",
    "        backcast = torch.einsum(\"bp,pt->bt\", backcast_theta, self.backcast_basis)\n",
    "        forecast = torch.einsum(\"bpq,pt->btq\", forecast_theta, self.forecast_basis)\n",
    "        return backcast, forecast\n",
    "\n",
    "class ExogenousBasis(nn.Module):\n",
    "    # Reference: https://github.com/cchallu/nbeatsx\n",
    "    def __init__(self, forecast_size: int):\n",
    "        super().__init__()\n",
    "        self.forecast_size = forecast_size\n",
    "\n",
    "    def forward(self, theta: torch.Tensor, futr_exog: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        backcast_basis = futr_exog[:, :-self.forecast_size, :].permute(0, 2, 1)\n",
    "        forecast_basis = futr_exog[:, -self.forecast_size:, :].permute(0, 2, 1)\n",
    "        cut_point = forecast_basis.shape[1]\n",
    "        backcast_theta=theta[:, cut_point:]\n",
    "        forecast_theta=theta[:, :cut_point].reshape(\n",
    "            len(theta), cut_point, -1\n",
    "        )\n",
    "     \n",
    "        backcast = torch.einsum('bp,bpt->bt', backcast_theta, backcast_basis)\n",
    "        forecast = torch.einsum('bpq,bpt->btq', forecast_theta, forecast_basis)\n",
    "        \n",
    "        return backcast, forecast\n",
    "\n",
    "class SeasonalityBasis(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        harmonics: int,\n",
    "        backcast_size: int,\n",
    "        forecast_size: int,\n",
    "        out_features: int = 1,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.out_features = out_features\n",
    "        frequency = np.append(\n",
    "            np.zeros(1, dtype=float),\n",
    "            np.arange(harmonics, harmonics / 2 * forecast_size, dtype=float)\n",
    "            / harmonics,\n",
    "        )[None, :]\n",
    "        backcast_grid = (\n",
    "            -2\n",
    "            * np.pi\n",
    "            * (np.arange(backcast_size, dtype=float)[:, None] / forecast_size)\n",
    "            * frequency\n",
    "        )\n",
    "        forecast_grid = (\n",
    "            2\n",
    "            * np.pi\n",
    "            * (np.arange(forecast_size, dtype=float)[:, None] / forecast_size)\n",
    "            * frequency\n",
    "        )\n",
    "\n",
    "        backcast_cos_template = torch.tensor(\n",
    "            np.transpose(np.cos(backcast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        backcast_sin_template = torch.tensor(\n",
    "            np.transpose(np.sin(backcast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        backcast_template = torch.cat(\n",
    "            [backcast_cos_template, backcast_sin_template], dim=0\n",
    "        )\n",
    "\n",
    "        forecast_cos_template = torch.tensor(\n",
    "            np.transpose(np.cos(forecast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        forecast_sin_template = torch.tensor(\n",
    "            np.transpose(np.sin(forecast_grid)), dtype=torch.float32\n",
    "        )\n",
    "        forecast_template = torch.cat(\n",
    "            [forecast_cos_template, forecast_sin_template], dim=0\n",
    "        )\n",
    "\n",
    "        self.backcast_basis = nn.Parameter(backcast_template, requires_grad=False)\n",
    "        self.forecast_basis = nn.Parameter(forecast_template, requires_grad=False)\n",
    "\n",
    "    def forward(self, theta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        harmonic_size = self.forecast_basis.shape[0]  # [harmonic_size, L+H]\n",
    "        backcast_theta = theta[:, :harmonic_size]\n",
    "        forecast_theta = theta[:, harmonic_size:]\n",
    "        forecast_theta = forecast_theta.reshape(len(forecast_theta), harmonic_size, -1)\n",
    "        backcast = torch.einsum(\"bp,pt->bt\", backcast_theta, self.backcast_basis)\n",
    "        forecast = torch.einsum(\"bpq,pt->btq\", forecast_theta, self.forecast_basis)\n",
    "        return backcast, forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17382790-7d84-4a89-959b-5676afa46392",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "ACTIVATIONS = [\"ReLU\", \"Softplus\", \"Tanh\", \"SELU\", \"LeakyReLU\", \"PReLU\", \"Sigmoid\"]\n",
    "\n",
    "\n",
    "class NBEATSBlock(nn.Module):\n",
    "    \"\"\"\n",
    "    N-BEATS block which takes a basis function as an argument.\n",
    "\n",
    "    The block concatenates the target lags with the flattened historical, future\n",
    "    and static exogenous inputs, maps them with an MLP to the coefficients `theta`,\n",
    "    and lets `basis` turn `theta` into a (backcast, forecast) pair.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_size: int,\n",
    "        h: int,\n",
    "        futr_input_size: int,\n",
    "        hist_input_size: int,\n",
    "        stat_input_size: int,\n",
    "        n_theta: int,\n",
    "        mlp_units: list,\n",
    "        basis: nn.Module,\n",
    "        dropout_prob: float,\n",
    "        activation: str,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `input_size`: int, number of autoregressive lags L.<br>\n",
    "        `h`: int, forecast horizon H.<br>\n",
    "        `futr_input_size`: int, number of future exogenous variables.<br>\n",
    "        `hist_input_size`: int, number of historical exogenous variables.<br>\n",
    "        `stat_input_size`: int, number of static exogenous variables.<br>\n",
    "        `n_theta`: int, number of coefficients the MLP produces for `basis`.<br>\n",
    "        `mlp_units`: list, [in_features, out_features] pair of each hidden layer.<br>\n",
    "        `basis`: nn.Module, maps `theta` to (backcast, forecast).<br>\n",
    "        `dropout_prob`: float, dropout after each hidden activation (disabled when 0).<br>\n",
    "        `activation`: str, name of a `torch.nn` activation listed in `ACTIVATIONS`.<br>\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.dropout_prob = dropout_prob\n",
    "        self.futr_input_size = futr_input_size\n",
    "        self.hist_input_size = hist_input_size\n",
    "        self.stat_input_size = stat_input_size\n",
    "        # Kept so static exogenous can be broadcast over the L + H window in `forward`.\n",
    "        self.input_size = input_size\n",
    "        self.h = h\n",
    "\n",
    "        assert activation in ACTIVATIONS, f\"{activation} is not in {ACTIVATIONS}\"\n",
    "        # One activation instance is appended after every hidden layer, so a\n",
    "        # parametric activation (e.g. PReLU) shares its parameters across layers.\n",
    "        activ = getattr(nn, activation)()\n",
    "\n",
    "        # Input vector for the block is\n",
    "        # y_lags (input_size) + historical exogenous (hist_input_size*input_size) +\n",
    "        # future exogenous (futr_input_size*(input_size+h)) + static exogenous (stat_input_size)\n",
    "        # [ Y_[t-L:t], X_[t-L:t], F_[t-L:t+H], S ]\n",
    "        mlp_input_size = (\n",
    "            input_size\n",
    "            + hist_input_size * input_size\n",
    "            + futr_input_size * (input_size + h)\n",
    "            + stat_input_size\n",
    "        )\n",
    "\n",
    "        # The first Linear is not followed by an activation; each mlp_units\n",
    "        # entry then adds Linear -> activation (-> Dropout).\n",
    "        hidden_layers = [\n",
    "            nn.Linear(in_features=mlp_input_size, out_features=mlp_units[0][0])\n",
    "        ]\n",
    "        for layer in mlp_units:\n",
    "            hidden_layers.append(nn.Linear(in_features=layer[0], out_features=layer[1]))\n",
    "            hidden_layers.append(activ)\n",
    "\n",
    "            if self.dropout_prob > 0:\n",
    "                hidden_layers.append(nn.Dropout(p=self.dropout_prob))\n",
    "\n",
    "        output_layer = [nn.Linear(in_features=mlp_units[-1][1], out_features=n_theta)]\n",
    "        layers = hidden_layers + output_layer\n",
    "        self.layers = nn.Sequential(*layers)\n",
    "        self.basis = basis\n",
    "\n",
    "    def _exogenous_basis_inputs(\n",
    "        self, futr_exog: torch.Tensor, stat_exog: torch.Tensor\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"Build the [B, T, F + S] covariates consumed by `ExogenousBasis`.\n",
    "\n",
    "        Static exogenous given without a time axis ([B, S]) is repeated over the\n",
    "        T steps of the window so it can act as a time-constant basis and be\n",
    "        concatenated with `futr_exog` along the feature axis.\n",
    "        \"\"\"\n",
    "        if self.futr_input_size == 0 and self.stat_input_size == 0:\n",
    "            raise ValueError(\"No stats or future exogenous. ExogenousBlock not supported.\")\n",
    "        if self.stat_input_size == 0:\n",
    "            return futr_exog\n",
    "\n",
    "        if self.futr_input_size > 0:\n",
    "            n_steps = futr_exog.shape[1]\n",
    "        else:\n",
    "            n_steps = self.input_size + self.h\n",
    "        if stat_exog.dim() == 2:\n",
    "            stat_exog = stat_exog.unsqueeze(1).expand(-1, n_steps, -1)\n",
    "\n",
    "        if self.futr_input_size == 0:\n",
    "            return stat_exog\n",
    "        return torch.cat((futr_exog, stat_exog), dim=2)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        insample_y: torch.Tensor,\n",
    "        futr_exog: torch.Tensor,\n",
    "        hist_exog: torch.Tensor,\n",
    "        stat_exog: torch.Tensor,\n",
    "    ) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        # Flatten MLP inputs [B, L+H, C] -> [B, (L+H)*C]\n",
    "        # Concatenate [ Y_t, | X_{t-L},..., X_{t} | F_{t-L},..., F_{t+H} | S ]\n",
    "        batch_size = len(insample_y)\n",
    "        if self.hist_input_size > 0:\n",
    "            insample_y = torch.cat(\n",
    "                (insample_y, hist_exog.reshape(batch_size, -1)), dim=1\n",
    "            )\n",
    "\n",
    "        if self.futr_input_size > 0:\n",
    "            insample_y = torch.cat(\n",
    "                (insample_y, futr_exog.reshape(batch_size, -1)), dim=1\n",
    "            )\n",
    "\n",
    "        if self.stat_input_size > 0:\n",
    "            insample_y = torch.cat(\n",
    "                (insample_y, stat_exog.reshape(batch_size, -1)), dim=1\n",
    "            )\n",
    "\n",
    "        # Compute local projection weights and projection\n",
    "        theta = self.layers(insample_y)\n",
    "\n",
    "        if isinstance(self.basis, ExogenousBasis):\n",
    "            exog_basis = self._exogenous_basis_inputs(futr_exog, stat_exog)\n",
    "            return self.basis(theta, exog_basis)\n",
    "        return self.basis(theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be997aeb-778f-442d-a97a-ff47de2deab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class NBEATSx(BaseWindows):\n",
    "    \"\"\"NBEATSx\n",
    "\n",
    "    The Neural Basis Expansion Analysis with Exogenous variables (NBEATSx) is a simple\n",
    "    and effective deep learning architecture. It is built with a deep stack of MLPs with\n",
    "    doubly residual connections. The NBEATSx architecture includes additional exogenous\n",
    "    blocks, extending NBEATS capabilities and interpretability. With its interpretable\n",
    "    version, NBEATSx decomposes its predictions on seasonality, trend, and exogenous effects.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
    "    `n_harmonics`: int=2, Number of harmonic oscillations in the SeasonalityBasis [cos(i * t/n_harmonics), sin(i * t/n_harmonics)]. Note that it will only be used if 'seasonality' is in `stack_types`.<br>\n",
    "    `n_polynomials`: int=2, Number of polynomial terms for TrendBasis [1,t,...,t^n_poly]. Note that it will only be used if 'trend' is in `stack_types`.<br>\n",
    "    `stack_types`: List[str], List of stack types. Subset from ['seasonality', 'trend', 'identity', 'exogenous']; 'exogenous' requires `futr_exog_list` or `stat_exog_list`.<br>\n",
    "    `n_blocks`: List[int], Number of blocks for each stack. Note that len(n_blocks) = len(stack_types).<br>\n",
    "    `mlp_units`: List[List[int]], Hidden layers of the MLP used by every block. Each internal list is the [in_features, out_features] pair of one hidden layer; the block input is first projected to mlp_units[0][0] units.<br>\n",
    "    `dropout_prob_theta`: float=0.0, Float between [0, 1). Dropout probability applied after each hidden activation of the blocks' MLPs.<br>\n",
    "    `activation`: str, activation from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'PReLU', 'Sigmoid'].<br>\n",
    "    `shared_weights`: bool=False, if True, every block of a stack reuses the weights of the stack's first block.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=3, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch.<br>\n",
    "    `inference_windows_batch_size`: int=-1, number of windows to sample in each inference batch, -1 uses all.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random seed initialization for replicability.<br>\n",
    "    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional, Custom name of the model.<br>\n",
    "    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
    "    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
    "    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>\n",
    "    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>\n",
    "    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    -[Kin G. Olivares, Cristian Challu, Grzegorz Marcjasz, Rafał Weron, Artur Dubrawski (2021).\n",
    "    \"Neural basis expansion analysis with exogenous variables: Forecasting electricity prices with NBEATSx\".](https://arxiv.org/abs/2104.05522)\n",
    "    \"\"\"\n",
    "\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = \"windows\"\n",
    "    EXOGENOUS_FUTR = True\n",
    "    EXOGENOUS_HIST = True\n",
    "    EXOGENOUS_STAT = True\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        h,\n",
    "        input_size,\n",
    "        futr_exog_list=None,\n",
    "        hist_exog_list=None,\n",
    "        stat_exog_list=None,\n",
    "        exclude_insample_y=False,\n",
    "        n_harmonics=2,\n",
    "        n_polynomials=2,\n",
    "        stack_types: list = [\"identity\", \"trend\", \"seasonality\"],\n",
    "        n_blocks: list = [1, 1, 1],\n",
    "        mlp_units: list = 3 * [[512, 512]],\n",
    "        dropout_prob_theta=0.0,\n",
    "        activation=\"ReLU\",\n",
    "        shared_weights=False,\n",
    "        loss=MAE(),\n",
    "        valid_loss=None,\n",
    "        max_steps: int = 1000,\n",
    "        learning_rate: float = 1e-3,\n",
    "        num_lr_decays: int = 3,\n",
    "        early_stop_patience_steps: int = -1,\n",
    "        val_check_steps: int = 100,\n",
    "        batch_size=32,\n",
    "        valid_batch_size: Optional[int] = None,\n",
    "        windows_batch_size: int = 1024,\n",
    "        inference_windows_batch_size: int = -1,\n",
    "        start_padding_enabled: bool = False,\n",
    "        step_size: int = 1,\n",
    "        scaler_type: str = \"identity\",\n",
    "        random_seed: int = 1,\n",
    "        num_workers_loader: int = 0,\n",
    "        drop_last_loader: bool = False,\n",
    "        optimizer = None,\n",
    "        optimizer_kwargs = None,\n",
    "        lr_scheduler = None,\n",
    "        lr_scheduler_kwargs = None,\n",
    "        **trainer_kwargs,\n",
    "    ):\n",
    "        # Protect horizon collapsed seasonality and trend NBEATSx-i basis\n",
    "        if h == 1 and ((\"seasonality\" in stack_types) or (\"trend\" in stack_types)):\n",
    "            raise Exception(\n",
    "                \"Horizon `h=1` incompatible with `seasonality` or `trend` in stacks\"\n",
    "            )\n",
    "\n",
    "        # Inherit BaseWindows class\n",
    "        super().__init__(\n",
    "            h=h,\n",
    "            input_size=input_size,\n",
    "            futr_exog_list=futr_exog_list,\n",
    "            hist_exog_list=hist_exog_list,\n",
    "            stat_exog_list=stat_exog_list,\n",
    "            exclude_insample_y=exclude_insample_y,\n",
    "            loss=loss,\n",
    "            valid_loss=valid_loss,\n",
    "            max_steps=max_steps,\n",
    "            learning_rate=learning_rate,\n",
    "            num_lr_decays=num_lr_decays,\n",
    "            early_stop_patience_steps=early_stop_patience_steps,\n",
    "            val_check_steps=val_check_steps,\n",
    "            batch_size=batch_size,\n",
    "            valid_batch_size=valid_batch_size,\n",
    "            windows_batch_size=windows_batch_size,\n",
    "            inference_windows_batch_size=inference_windows_batch_size,\n",
    "            start_padding_enabled=start_padding_enabled,\n",
    "            step_size=step_size,\n",
    "            scaler_type=scaler_type,\n",
    "            num_workers_loader=num_workers_loader,\n",
    "            drop_last_loader=drop_last_loader,\n",
    "            random_seed=random_seed,\n",
    "            optimizer=optimizer,\n",
    "            optimizer_kwargs=optimizer_kwargs,\n",
    "            lr_scheduler=lr_scheduler,\n",
    "            lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
    "            **trainer_kwargs,\n",
    "        )\n",
    "\n",
    "        # Architecture\n",
    "        blocks = self.create_stack(\n",
    "            h=h,\n",
    "            input_size=input_size,\n",
    "            futr_input_size=self.futr_exog_size,\n",
    "            hist_input_size=self.hist_exog_size,\n",
    "            stat_input_size=self.stat_exog_size,\n",
    "            stack_types=stack_types,\n",
    "            n_blocks=n_blocks,\n",
    "            mlp_units=mlp_units,\n",
    "            dropout_prob_theta=dropout_prob_theta,\n",
    "            activation=activation,\n",
    "            shared_weights=shared_weights,\n",
    "            n_polynomials=n_polynomials,\n",
    "            n_harmonics=n_harmonics,\n",
    "        )\n",
    "        self.blocks = torch.nn.ModuleList(blocks)\n",
    "\n",
    "        # Adapter with Loss dependent dimensions.\n",
    "        # NOTE(review): `self.out` is never used in `forward`; removing it would\n",
    "        # change the module's parameters/state_dict, so it is left in place.\n",
    "        if self.loss.outputsize_multiplier > 1:\n",
    "            self.out = nn.Linear(\n",
    "                in_features=h, out_features=h * self.loss.outputsize_multiplier\n",
    "            )\n",
    "\n",
    "    def create_stack(\n",
    "        self,\n",
    "        h,\n",
    "        input_size,\n",
    "        stack_types,\n",
    "        n_blocks,\n",
    "        mlp_units,\n",
    "        dropout_prob_theta,\n",
    "        activation,\n",
    "        shared_weights,\n",
    "        n_polynomials,\n",
    "        n_harmonics,\n",
    "        futr_input_size,\n",
    "        hist_input_size,\n",
    "        stat_input_size,\n",
    "    ):\n",
    "        \"\"\"Build the flat list of `NBEATSBlock`s, stack by stack.\n",
    "\n",
    "        Raises `ValueError` for an unknown stack type, or for an `exogenous`\n",
    "        stack when there are no future or static exogenous variables.\n",
    "        \"\"\"\n",
    "        block_list = []\n",
    "        for i, stack_type in enumerate(stack_types):\n",
    "            for block_id in range(n_blocks[i]):\n",
    "                # Shared weights: reuse the stack's previous (hence first) block.\n",
    "                if shared_weights and block_id > 0:\n",
    "                    nbeats_block = block_list[-1]\n",
    "                else:\n",
    "                    if stack_type == \"seasonality\":\n",
    "                        n_theta = (\n",
    "                            2\n",
    "                            * (self.loss.outputsize_multiplier + 1)\n",
    "                            * int(np.ceil(n_harmonics / 2 * h) - (n_harmonics - 1))\n",
    "                        )\n",
    "                        basis = SeasonalityBasis(\n",
    "                            harmonics=n_harmonics,\n",
    "                            backcast_size=input_size,\n",
    "                            forecast_size=h,\n",
    "                            out_features=self.loss.outputsize_multiplier,\n",
    "                        )\n",
    "\n",
    "                    elif stack_type == \"trend\":\n",
    "                        n_theta = (self.loss.outputsize_multiplier + 1) * (\n",
    "                            n_polynomials + 1\n",
    "                        )\n",
    "                        basis = TrendBasis(\n",
    "                            degree_of_polynomial=n_polynomials,\n",
    "                            backcast_size=input_size,\n",
    "                            forecast_size=h,\n",
    "                            out_features=self.loss.outputsize_multiplier,\n",
    "                        )\n",
    "\n",
    "                    elif stack_type == \"identity\":\n",
    "                        n_theta = input_size + self.loss.outputsize_multiplier * h\n",
    "                        basis = IdentityBasis(\n",
    "                            backcast_size=input_size,\n",
    "                            forecast_size=h,\n",
    "                            out_features=self.loss.outputsize_multiplier,\n",
    "                        )\n",
    "\n",
    "                    elif stack_type == \"exogenous\":\n",
    "                        # The exogenous basis is built from future and static covariates only.\n",
    "                        if futr_input_size + stat_input_size == 0:\n",
    "                            raise ValueError(\n",
    "                                \"Stack type `exogenous` requires future or static exogenous \"\n",
    "                                \"variables (`futr_exog_list` or `stat_exog_list`).\"\n",
    "                            )\n",
    "                        # NOTE(review): n_theta does not scale with loss.outputsize_multiplier,\n",
    "                        # so this block's forecast has a single output channel that\n",
    "                        # broadcasts over the other blocks' outputs in `forward`.\n",
    "                        n_theta = 2 * (futr_input_size + stat_input_size)\n",
    "                        basis = ExogenousBasis(forecast_size=h)\n",
    "\n",
    "                    else:\n",
    "                        raise ValueError(f\"Block type {stack_type} not found!\")\n",
    "\n",
    "                    nbeats_block = NBEATSBlock(\n",
    "                        input_size=input_size,\n",
    "                        h=h,\n",
    "                        futr_input_size=futr_input_size,\n",
    "                        hist_input_size=hist_input_size,\n",
    "                        stat_input_size=stat_input_size,\n",
    "                        n_theta=n_theta,\n",
    "                        mlp_units=mlp_units,\n",
    "                        basis=basis,\n",
    "                        dropout_prob=dropout_prob_theta,\n",
    "                        activation=activation,\n",
    "                    )\n",
    "\n",
    "                block_list.append(nbeats_block)\n",
    "\n",
    "        return block_list\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        # Parse windows_batch\n",
    "        insample_y = windows_batch[\"insample_y\"]\n",
    "        insample_mask = windows_batch[\"insample_mask\"]\n",
    "        futr_exog = windows_batch[\"futr_exog\"]\n",
    "        hist_exog = windows_batch[\"hist_exog\"]\n",
    "        stat_exog = windows_batch[\"stat_exog\"]\n",
    "\n",
    "        # NBEATSx' forward\n",
    "        residuals = insample_y.flip(dims=(-1,))  # backcast init\n",
    "        insample_mask = insample_mask.flip(dims=(-1,))\n",
    "\n",
    "        forecast = insample_y[:, -1:, None]  # Level with Naive1\n",
    "        block_forecasts = [forecast.repeat(1, self.h, 1)]\n",
    "        for block in self.blocks:\n",
    "            backcast, block_forecast = block(\n",
    "                insample_y=residuals,\n",
    "                futr_exog=futr_exog,\n",
    "                hist_exog=hist_exog,\n",
    "                stat_exog=stat_exog,\n",
    "            )\n",
    "            # Doubly residual stacking: subtract what this block explains from its\n",
    "            # input (positions where insample_mask is 0 are zeroed) and accumulate\n",
    "            # its forecast.\n",
    "            residuals = (residuals - backcast) * insample_mask\n",
    "            forecast = forecast + block_forecast\n",
    "\n",
    "            if self.decompose_forecast:\n",
    "                block_forecasts.append(block_forecast)\n",
    "\n",
    "        # Adapting output's domain\n",
    "        forecast = self.loss.domain_map(forecast)\n",
    "\n",
    "        if self.decompose_forecast:\n",
    "            # (n_batch, n_blocks + 1, h): entry 0 is the Naive1 level, then one per block\n",
    "            block_forecasts = torch.stack(block_forecasts)\n",
    "            block_forecasts = block_forecasts.permute(1, 0, 2, 3)\n",
    "            block_forecasts = block_forecasts.squeeze(-1)  # univariate output\n",
    "            return block_forecasts\n",
    "        else:\n",
    "            return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c57a831f-94bc-4616-b579-c114c3fc57c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBEATSx)  # API reference for the NBEATSx class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9013b63-f65b-4a92-913c-b696e6e69914",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBEATSx.fit, name='NBEATSx.fit')  # API reference for NBEATSx.fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a66184ee-7a71-4598-976c-c79b83089a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBEATSx.predict, name='NBEATSx.predict')  # API reference for NBEATSx.predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "360618d9-1049-4172-af97-aff19335f78b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from neuralforecast.losses.pytorch import MQLoss\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader\n",
    "from neuralforecast.utils import AirPassengersDF as Y_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fa496f4-d7f1-4fdc-87d6-ea497f853cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Calendar features used as future exogenous variables. `assign` returns a new\n",
    "# frame, so the imported AirPassengersDF object itself is not mutated.\n",
    "Y_df = Y_df.assign(month=Y_df['ds'].dt.month, year=Y_df['ds'].dt.year)\n",
    "\n",
    "Y_train_df = Y_df[Y_df.ds<Y_df['ds'].values[-12]] # 132 train\n",
    "# Explicit copy: a prediction column is added to the test frame below\n",
    "Y_test_df = Y_df[Y_df.ds>=Y_df['ds'].values[-12]].copy()   # 12 test\n",
    "\n",
    "dataset, *_ = TimeSeriesDataset.from_df(df = Y_train_df)\n",
    "model = NBEATSx(h=12,\n",
    "                input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                windows_batch_size=None,\n",
    "                max_steps=1)\n",
    "model.fit(dataset=dataset)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat = model.predict(dataset=dataset2)\n",
    "Y_test_df['NBEATSx'] = y_hat\n",
    "\n",
    "# Drop the id and exogenous columns so only `y` and the forecast are plotted\n",
    "pd.concat([Y_train_df, Y_test_df]).drop(['unique_id','month','year'], axis=1).set_index('ds').plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db94b63e-d82c-423f-8f75-184ae285904d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Calling predict again on the same fitted model must give identical values\n",
    "y_hat_again = model.predict(dataset=dataset2)\n",
    "test_eq(y_hat_again, y_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46090447-8e67-4f08-8a3d-9547183983f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Fitting on the full series with test_size=12 must reproduce the forecast of\n",
    "# the model trained on the train split only (test no leakage with test_size)\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "nbeatsx_config = dict(\n",
    "    h=12,\n",
    "    input_size=24,\n",
    "    scaler_type='robust',\n",
    "    stack_types=[\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "    n_blocks=[1, 1, 1, 1],\n",
    "    futr_exog_list=['month', 'year'],\n",
    "    windows_batch_size=None,\n",
    "    max_steps=1,\n",
    ")\n",
    "model = NBEATSx(**nbeatsx_config)\n",
    "model.fit(dataset=dataset, test_size=12)\n",
    "y_hat_test = model.predict(dataset=dataset, step_size=1)\n",
    "np.testing.assert_almost_equal(y_hat, y_hat_test, decimal=4)\n",
    "\n",
    "# A second predict call must return exactly the same values\n",
    "y_hat_test2 = model.predict(dataset=dataset, step_size=1)\n",
    "test_eq(y_hat_test, y_hat_test2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a87652b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#test no leakage with test_size\n",
    "# NOTE(review): this cell repeats the previous test verbatim. Because it refits\n",
    "# from scratch it only re-checks that an independent fit reproduces `y_hat`;\n",
    "# consider removing it or turning it into a distinct scenario.\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(\n",
    "    h=12, input_size=24, scaler_type='robust',\n",
    "    stack_types=[\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "    n_blocks=[1] * 4,\n",
    "    futr_exog_list=['month', 'year'],\n",
    "    windows_batch_size=None, max_steps=1,\n",
    ")\n",
    "model.fit(dataset=dataset, test_size=12)\n",
    "predict_kwargs = dict(dataset=dataset, step_size=1)\n",
    "y_hat_test = model.predict(**predict_kwargs)\n",
    "np.testing.assert_almost_equal(y_hat, y_hat_test, decimal=4)\n",
    "#test we recover the same forecast\n",
    "y_hat_test2 = model.predict(**predict_kwargs)\n",
    "test_eq(y_hat_test, y_hat_test2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d7648e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test seasonality/trend basis protection: relies on the default `stack_types`\n",
    "# containing trend/seasonality. The model is instantiated normally rather than\n",
    "# calling `__init__` with an unrelated class bound as `self`, which could leak\n",
    "# attribute assignments onto that class.\n",
    "test_fail(NBEATSx, \n",
    "          contains='Horizon `h=1` incompatible with `seasonality` or `trend` in stacks',\n",
    "          kwargs=dict(h=1, input_size=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fde23c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test inference_windows_batch_size: predictions must not depend on how many\n",
    "# windows are processed per inference batch\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(\n",
    "    h=12,\n",
    "    input_size=24,\n",
    "    scaler_type='robust',\n",
    "    stack_types=[\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "    n_blocks=[1, 1, 1, 1],\n",
    "    futr_exog_list=['month', 'year'],\n",
    "    windows_batch_size=None,\n",
    "    inference_windows_batch_size=1,\n",
    "    max_steps=1,\n",
    ")\n",
    "model.fit(dataset=dataset, test_size=12)\n",
    "preds_one_window = model.predict(dataset=dataset, step_size=1)\n",
    "# -1 presumably means all windows in a single batch -- TODO confirm\n",
    "model.inference_windows_batch_size = -1\n",
    "preds_all_windows = model.predict(dataset=dataset, step_size=1)\n",
    "test_eq(preds_one_window, preds_all_windows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8501c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Check val_check_steps protection to less than max_steps: asking for validation\n",
    "# every 5 steps with max_steps=1 must yield a validation interval of 1\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = NBEATSx(h=12, input_size=24,\n",
    "                scaler_type='robust',\n",
    "                stack_types=[\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks=[1, 1, 1, 1],\n",
    "                futr_exog_list=['month', 'year'],\n",
    "                windows_batch_size=None,\n",
    "                max_steps=1, val_check_steps=5,\n",
    "                early_stop_patience_steps=1)\n",
    "model.fit(dataset=dataset, test_size=12, val_size=12)\n",
    "expected_interval = 1\n",
    "test_eq(model.trainer_kwargs['val_check_interval'], expected_interval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e8ae90",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "Y_train_df = Y_df[Y_df.ds<Y_df['ds'].values[-12]] # 132 train\n",
    "Y_test_df = Y_df[Y_df.ds>=Y_df['ds'].values[-12]]   # 12 test\n",
    "\n",
    "# Fit MQ-NBEATSx: multi-quantile loss (MQLoss) at levels 80 and 90\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "model = NBEATSx(h=12, input_size=24, max_steps=1,\n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'],\n",
    "                loss=MQLoss(level=[80, 90]))\n",
    "model.fit(dataset=dataset, val_size=12)\n",
    "\n",
    "# Parse quantile predictions: one column per loss output, prefixed by the model name\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat = model.predict(dataset=dataset2)\n",
    "Y_hat_df = pd.DataFrame.from_records(data=y_hat,\n",
    "                columns=['NBEATSx'+q for q in model.loss.output_names],\n",
    "                index=Y_test_df.index)\n",
    "\n",
    "# Plot quantile predictions\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df]).drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['NBEATSx-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['NBEATSx-lo-90'][-12:].values, \n",
    "                 y2=plot_df['NBEATSx-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0298fce5-eb13-40dc-9964-b026fd2a8928",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test validation step\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "model = NBEATSx(h=12, input_size=24, \n",
    "                windows_batch_size=None, max_steps=1, \n",
    "                scaler_type='robust',\n",
    "                stack_types = [\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                n_blocks = [1,1,1,1],\n",
    "                futr_exog_list=['month','year'])\n",
    "model.fit(dataset=dataset, val_size=12)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat_w_val = model.predict(dataset=dataset2)\n",
    "# Y_test_df was built by boolean-filtering Y_df: copy it before adding a\n",
    "# column to avoid pandas' SettingWithCopyWarning\n",
    "Y_test_df = Y_test_df.copy()\n",
    "Y_test_df['N-BEATS'] = y_hat_w_val\n",
    "\n",
    "# Drop the id and exogenous calendar columns so only `y` and the forecast are plotted\n",
    "pd.concat([Y_train_df, Y_test_df]).drop(['unique_id', 'month', 'year'], axis=1).set_index('ds').plot()\n",
    "plt.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f987ed0-ee6e-4f66-bd8f-96acc6fbd56c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# test no leakage with test_size and val_size: a model fitted on the train split\n",
    "# and a model fitted on the full series with the last 12 points held out via\n",
    "# test_size must forecast the same test horizon.\n",
    "def _make_nbeatsx():\n",
    "    \"\"\"Build the small exogenous NBEATSx configuration shared by both fits.\"\"\"\n",
    "    return NBEATSx(h=12, input_size=24, windows_batch_size=None, max_steps=1,\n",
    "                   scaler_type='robust',\n",
    "                   stack_types=[\"identity\", \"trend\", \"seasonality\", \"exogenous\"],\n",
    "                   n_blocks=[1, 1, 1, 1],\n",
    "                   futr_exog_list=['month', 'year'])\n",
    "\n",
    "# Fit on the train split, then forecast the 12 test points\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
    "model = _make_nbeatsx()\n",
    "model.fit(dataset=dataset, val_size=12)\n",
    "dataset2 = dataset.update_dataset(dataset, Y_test_df)\n",
    "model.set_test_size(12)\n",
    "y_hat_w_val = model.predict(dataset=dataset2)\n",
    "\n",
    "# Fit on the full series with the test window held out\n",
    "dataset, *_ = TimeSeriesDataset.from_df(Y_df)\n",
    "model = _make_nbeatsx()\n",
    "model.fit(dataset=dataset, val_size=12, test_size=12)\n",
    "y_hat_test_w_val = model.predict(dataset=dataset, step_size=1)\n",
    "\n",
    "np.testing.assert_almost_equal(y_hat_test_w_val, y_hat_w_val, decimal=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "036be4bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# qualitative decomposition evaluation\n",
    "y_hat = model.decompose(dataset=dataset)\n",
    "\n",
    "# Presumably ordered as in `forward` with decompose_forecast: the Naive1 level\n",
    "# first, then one component per block (n_blocks=[1,1,1,1] in the fit above).\n",
    "# Each entry is (y-axis label, line label, color).\n",
    "components = [\n",
    "    ('Level', 'level', \"#7B3841\"),\n",
    "    ('Identity', 'stack1', \"#7B3841\"),\n",
    "    ('Trend', 'stack2', \"#D9AE9E\"),\n",
    "    ('Seasonality', 'stack3', \"#D9AE9E\"),\n",
    "    ('Exogenous', 'stack4', \"#D9AE9E\"),\n",
    "]\n",
    "assert y_hat.shape[1] == len(components), f'unexpected decomposition shape: {y_hat.shape}'\n",
    "\n",
    "fig, ax = plt.subplots(len(components) + 1, 1, figsize=(10, 15))\n",
    "\n",
    "# Top panel: true test values against the sum of all components\n",
    "ax[0].plot(Y_test_df['y'].values, label='True', color=\"#9C9DB2\", linewidth=4)\n",
    "ax[0].plot(y_hat.sum(axis=1).flatten(), label='Forecast', color=\"#7B3841\")\n",
    "ax[0].grid()\n",
    "ax[0].legend(prop={'size': 20})\n",
    "for label in (ax[0].get_xticklabels() + ax[0].get_yticklabels()):\n",
    "    label.set_fontsize(18)\n",
    "ax[0].set_ylabel('y', fontsize=20)\n",
    "\n",
    "# One panel per component of the first window\n",
    "for k, (ylabel, line_label, color) in enumerate(components):\n",
    "    panel = ax[k + 1]\n",
    "    panel.plot(y_hat[0, k], label=line_label, color=color)\n",
    "    panel.grid()\n",
    "    panel.set_ylabel(ylabel, fontsize=20)\n",
    "\n",
    "ax[-1].set_xlabel('Prediction \\u03C4 \\u2208 {t+1,..., t+H}', fontsize=20);"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "358dd690",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68681f2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import NBEATSx\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
    "\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "# A Normal DistributionLoss is used here; to train with a quantile loss instead,\n",
    "# pass loss=MQLoss(level=[80, 90]) and adjust the plotted columns accordingly.\n",
    "model = NBEATSx(h=12, input_size=24,\n",
    "                loss=DistributionLoss(distribution='Normal', level=[80, 90]),\n",
    "                scaler_type='robust',\n",
    "                dropout_prob_theta=0.5,\n",
    "                stat_exog_list=['airline1'],\n",
    "                futr_exog_list=['trend'],\n",
    "                max_steps=200,\n",
    "                val_check_steps=10,\n",
    "                early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='M'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "Y_hat_df = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Plot quantile predictions\n",
    "Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['NBEATSx'], c='purple', label='mean')\n",
    "plt.plot(plot_df['ds'], plot_df['NBEATSx-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['NBEATSx-lo-90'][-12:].values, \n",
    "                 y2=plot_df['NBEATSx-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
